import tensorflow as tf  # 导入 TF 库

def _dense_layer_params(n_in, n_out):
    """Create the (weight, bias) Variables for one fully connected layer.

    The weight has shape [n_in, n_out] and is drawn from a truncated normal
    distribution with stddev 0.1; the bias has shape [n_out] and starts at 0.
    """
    weight = tf.Variable(tf.random.truncated_normal([n_in, n_out], stddev=0.1))
    bias = tf.Variable(tf.zeros([n_out]))
    return weight, bias


# Parameters of a 784 -> 256 -> 128 -> 10 fully connected stack.
w1, b1 = _dense_layer_params(784, 256)
w2, b2 = _dense_layer_params(256, 128)
w3, b3 = _dense_layer_params(128, 10)

# Random stand-in input: 2 samples of 28x28 values, flattened to shape
# [2, 784] so it can be multiplied by w1 ([784, 256]).
x = tf.random.normal([2, 28, 28])
x = tf.reshape(x, [-1, 28 * 28])
print(x.shape)

# Layer 1: [batch, 784] @ [784, 256] -> [batch, 256].
# The [256] bias is broadcast across the batch dimension by `+` itself, the
# same way b2 and b3 are added below. An explicit tf.broadcast_to is
# unnecessary, and building its target shape from x.shape[0] would fail
# whenever the batch size is not statically known (x.shape[0] is None).
h1 = x @ w1 + b1
h1 = tf.nn.relu(h1)

# Layer 2: [batch, 256] @ [256, 128] -> [batch, 128].
h2 = h1 @ w2 + b2
h2 = tf.nn.relu(h2)

# Output layer: [batch, 128] @ [128, 10] -> [batch, 10]. No activation or
# softmax is applied, so `out` holds raw pre-activation values.
out = h2 @ w3 + b3
print(out)

# TODO: no loss is computed yet. An earlier draft had
# `loss = tf.reduce_mean(loss)`, which reads `loss` before it is ever defined;
# first compute a per-example loss from `out`, then reduce it to a scalar.
